[PyTorch] Add distributed Muon optimizer #2920
vcherepanov-nv wants to merge 6 commits into NVIDIA:main
Conversation
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary
This PR adds a distributed Muon optimizer based on newton_schulz orthogonalization.
Confidence Score: 5/5
Safe to merge; all findings are P2 suggestions with no blocking correctness or security issues. All findings are P2 (style/hardening): missing eps validation, __del__ teardown ordering, and import-time NUM_PROCS. The core optimizer math is correct, the distributed normalization is equivalent to the full-matrix reference, and previously discussed issues (closure/enable_grad, global_shape scaling) are properly handled in this version. transformer_engine/pytorch/optimizers/muon.py — eps validation gap and __del__ teardown ordering.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[step called] --> B{closure?}
B -- yes --> C[enable_grad + call closure]
B -- no --> D[iterate param groups]
C --> D
D --> E{p.grad is None?}
E -- yes --> F[skip]
E -- no --> G[resolve + validate partition_dim]
G --> H{decoupled weight decay?}
H -- yes --> I[p *= 1 - lr * wd]
H -- no, wd != 0 --> J[grad += wd * p]
I --> K[momentum_buffer.lerp_ grad]
J --> K
K --> L{nesterov?}
L -- yes --> M[update = grad.lerp momentum_buffer momentum]
L -- no --> N[update = momentum_buffer ref]
M --> O[_orthogonalize update]
N --> O
O --> P[clone + maybe transpose]
P --> Q[distributed_normalize_p2_ all_reduce norm]
Q --> R[newton_schulz distributed]
R --> S[maybe untranspose]
S --> T[scale by get_muon_scale_factor * extra]
T --> U[p += -lr * orth_update]
U --> V[return loss]
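For readers who prefer code to the diagram, the same flow reads roughly as the sketch below. This is a paraphrase of the flowchart, not the PR's implementation: the group keys, lerp weights, and state handling are assumptions, and _orthogonalize stands in for the distributed normalize + Newton-Schulz + rescale path.

```python
import torch

@torch.no_grad()
def step(self, closure=None):
    loss = None
    if closure is not None:
        with torch.enable_grad():          # closure may call backward()
            loss = closure()
    for group in self.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue                   # skip params without gradients
            grad = p.grad
            state = self.state[p]
            buf = state.setdefault("momentum_buffer", torch.zeros_like(p))
            # Weight decay: decoupled scales the parameter, coupled adds to the gradient.
            if group["decoupled_weight_decay"]:
                p.mul_(1.0 - group["lr"] * group["weight_decay"])
            elif group["weight_decay"] != 0.0:
                grad = grad.add(p, alpha=group["weight_decay"])
            buf.lerp_(grad, 1.0 - group["momentum"])
            if group["nesterov"]:
                update = grad.lerp(buf, group["momentum"])
            else:
                update = buf               # alias; see the aliasing note further down
            # Placeholder for the clone / transpose / p2-normalize / Newton-Schulz / scale path.
            update = self._orthogonalize(update, group)
            p.add_(update, alpha=-group["lr"])
    return loss
```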
Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
| def step(self, closure=None): | ||
| """Perform a single optimization step.""" | ||
| loss = None | ||
| if closure is not None: | ||
| loss = closure() | ||
Closure called inside @torch.no_grad(), preventing gradient computation
closure() is invoked while torch.no_grad() is active. Any loss.backward() call inside the closure will silently produce zero/no gradients. The standard PyTorch pattern (used in SGD, Adam, etc.) is to wrap the closure in with torch.enable_grad():.
| def step(self, closure=None): | |
| """Perform a single optimization step.""" | |
| loss = None | |
| if closure is not None: | |
| loss = closure() | |
| @torch.no_grad() | |
| def step(self, closure=None): | |
| """Perform a single optimization step.""" | |
| loss = None | |
| if closure is not None: | |
| with torch.enable_grad(): | |
| loss = closure() |
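As a usage-side illustration of why the enable_grad wrapper matters: a typical closure re-runs the forward and backward pass, which needs autograd even though step() itself runs under no_grad. The model, inputs, targets, and loss below are placeholders, not code from this PR.

```python
import torch

def closure():
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)  # placeholder loss
    loss.backward()  # needs autograd enabled, hence the enable_grad() wrapper in step()
    return loss

loss = optimizer.step(closure)
```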
| scale_mode: str, | ||
| extra_scale_factor: float, | ||
| eps: float, | ||
| ) -> torch.Tensor: | ||
| global_shape = [grad.size(0), grad.size(1)] | ||
| global_shape[partition_dim] *= world_size |
Reference global_shape incorrectly scales an already-full tensor
_reference_orthogonalize receives the full matrix (shape full_shape) but then multiplies global_shape[partition_dim] by world_size a second time. For partition_dim=1 with world_size=2 and full_shape=(96, 128) this gives global_shape=[96, 256], so get_muon_scale_factor returns max(96,256)^0.5 = 16. The optimizer, operating on the shard (96, 64), correctly reconstructs global_shape=[96, 128] and computes max(96,128)^0.5 ≈ 11.3. This √2 discrepancy means the reference cannot correctly validate the optimizer's output.
The global_shape[partition_dim] *= world_size line should be removed since the input is already the full matrix.
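The discrepancy is easy to reproduce with the numbers from this comment, assuming the spectral scale factor is max(size_out, size_in) ** 0.5 as used in the calculation above; spectral_scale below is a standalone stand-in for get_muon_scale_factor, for illustration only.

```python
def spectral_scale(size_out: int, size_in: int) -> float:
    # Illustrative stand-in for get_muon_scale_factor in "spectral" mode.
    return max(size_out, size_in) ** 0.5

world_size, partition_dim = 2, 1

# Reference path: the input is already the full (96, 128) matrix, but the buggy
# line scales the partition dimension a second time.
buggy_shape = [96, 128]
buggy_shape[partition_dim] *= world_size   # -> [96, 256]
print(spectral_scale(*buggy_shape))        # 16.0

# Optimizer path: reconstructs the global shape from the (96, 64) shard.
shard_shape = [96, 64]
shard_shape[partition_dim] *= world_size   # -> [96, 128]
print(spectral_scale(*shard_shape))        # ~11.31, i.e. a sqrt(2) gap
```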
| if mode == "unit_rms_norm": | ||
| return (size_out / size_in) ** 0.5 |
unit_rms_norm mode can divide by zero when size_in == 0
(size_out / size_in) ** 0.5 raises ZeroDivisionError when size_in is 0. While the optimizer validates that the partition dimension is non-empty, it doesn't ensure the other dimension is non-zero. Consider adding a guard or documenting that both dimensions must be strictly positive.
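A minimal guard along the lines suggested, written as a free function; the name and signature are illustrative, not the PR's API.

```python
def unit_rms_scale(size_out: int, size_in: int) -> float:
    # Validate both dimensions up front so the division cannot raise ZeroDivisionError.
    if size_out <= 0 or size_in <= 0:
        raise ValueError(f"Both matrix dimensions must be positive, got ({size_out}, {size_in})")
    return (size_out / size_in) ** 0.5
```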
| if group["nesterov"]: | ||
| update = grad.lerp(momentum_buffer, group["momentum"]) | ||
| else: | ||
| update = momentum_buffer |
Non-Nesterov update is an alias to momentum_buffer, not a copy
update = momentum_buffer holds a reference. If _orthogonalize ever modifies its input in-place in a future refactor, the momentum buffer will be silently corrupted. _orthogonalize currently clones the input immediately so this is safe today, but a defensive .clone() or comment would make the intent explicit.
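A defensive variant of the quoted branch, sketching the suggestion rather than a required change:

```python
if group["nesterov"]:
    update = grad.lerp(momentum_buffer, group["momentum"])
else:
    # Clone so any future in-place edits to `update` cannot corrupt the momentum buffer.
    update = momentum_buffer.clone()
```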
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
skyw left a comment
I'd advise NOT exposing it in the public API. Keep it test-only if that is the purpose.
Having an optimizer with mostly copied code invites fragmentation.
Before this, all optimizers TE provides are more optimized fused versions. I'd say a highly optimized fused Muon with a similar concept could be justified, but it would need more consideration because it has more dependencies on other parts of the training pipeline than elementwise optimizers do.
| on tensor-parallel parameter shards. The local parameter shard must represent a | ||
| partition of a logical 2D matrix across the provided NCCL process group. | ||
|
|
||
| Args: |
Q: Does TE use numpy style docstring instead of Google style?
| def __init__( | ||
| self, | ||
| params: Iterable[torch.nn.Parameter | dict], |
Nit: The type here doesn't match PyTorch's internal one. Should be fine for the purpose of this class.
| scale_mode: MuonScaleT = "spectral", | ||
| extra_scale_factor: float = 1.0, | ||
| process_group: Optional[dist.ProcessGroup] = None, | ||
| partition_dim: int = 1, |
| raise ValueError(f"Invalid weight_decay value: {weight_decay}") | ||
| if num_ns_steps < 1: | ||
| raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}") | ||
| if partition_dim not in (0, 1): |
Q: Does this class intend to support the non-distributed case? partition_dim would be -1 in TE in that case.
| if process_group is None: | ||
| if not dist.is_initialized(): | ||
| raise RuntimeError("MuonOptimizer requires torch.distributed to be initialized.") |
Same question as above regarding single-GPU support.
| if process_group is None: | ||
| if not dist.is_initialized(): | ||
| raise RuntimeError("MuonOptimizer requires torch.distributed to be initialized.") | ||
| process_group = dist.group.WORLD |
Suggestion: This silent behavior is dangerous. If the user forgets to pass the correct TP group, the wrong group will be used.
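One way to make the choice explicit, sketched under the assumption that the constructor keeps the same parameters; the error message wording is illustrative only.

```python
if process_group is None:
    # Fail loudly instead of silently falling back to the default WORLD group;
    # callers who really want the global group can pass dist.group.WORLD explicitly.
    raise ValueError(
        "MuonOptimizer: process_group must be given explicitly; "
        "pass dist.group.WORLD if the global group is intended."
    )
```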
| eps: float, | ||
| ) -> torch.Tensor: | ||
| self._validate_param(grad, partition_dim) | ||
| world_size = dist.get_world_size(self.process_group) |
Same suggestion as above. The silent behavior of a None process group falling back to the default is dangerous. (Understood that it comes from PyTorch for historical reasons.)
| global_shape[partition_dim] *= world_size | ||
|
|
||
| orth_grad = grad.clone() | ||
| transposed = partition_dim == 0 |
Attn: This follows the common row- and column-wise tensor parallelism used in most LLMs. It would be suboptimal for anything other than that. Add a comment if that assumption is made.
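A possible comment to attach to the quoted line, stating the assumption; the wording is a suggestion only.

```python
# Assumes the usual column-parallel (partition_dim=1) / row-parallel (partition_dim=0)
# tensor-parallel layouts; other partitionings of the logical 2D matrix may be
# handled sub-optimally.
transposed = partition_dim == 0
```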
The idea was to give something to users who use TE but not Megatron-LM. By fragmentation, do you mean that we want to encourage everyone to use Megatron-LM? Or that the optimizer is a relatively thin layer on top of the newton_schulz call, and users should have no trouble creating it themselves? I don't think we gain anything by putting it into tests, since we already have tests for the newton_schulz call. So we need to decide whether we want this PR or should abandon it altogether. @cyanguwa
Fragmentation means there will be different flavors of Muon in emerging optimizers and TE, plus a lot of copied code. TE can end up with a stalled feature when the emerging-optimizers version updates. Megatron-LM will always have its own version because there are implementation-specific things that need to be hooked together, for example how QKV is implemented, or fused SwiGLU.
Should we move newton_schulz.py to this directory? Also, how do we expect Megatron to call us for this functionality? Thanks.
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
Description
Add a distributed Muon optimizer, based on newton_schulz orthogonalization
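A hedged construction example: it assumes MuonOptimizer is importable from the new transformer_engine/pytorch/optimizers/muon.py module and that an lr argument exists alongside the constructor parameters quoted in the review; the exact import path, argument names, and defaults may differ.

```python
import torch.distributed as dist
from transformer_engine.pytorch.optimizers.muon import MuonOptimizer  # path assumed

dist.init_process_group("nccl")           # the optimizer requires torch.distributed
tp_group = dist.new_group(ranks=[0, 1])   # the tensor-parallel group

optimizer = MuonOptimizer(
    model.parameters(),                   # `model` is a placeholder
    lr=1e-3,                              # assumed argument / value
    scale_mode="spectral",
    process_group=tp_group,               # pass the TP group explicitly (see review discussion)
    partition_dim=1,                      # column-parallel shards
)
```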
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: